import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from functools import partial
from timm.layers import DropPath
# from flash_cosine_sim_attention import flash_cosine_sim_attention
import math
try:
    from .xshared_modules import RelativePositionBias, ContinuousPositionBias1D, MLP
except:
    from xshared_modules import RelativePositionBias, ContinuousPositionBias1D, MLP

def build_time_block(params):
    """
    Build a time-mixing block constructor from the parameter object.

    Args:
        params: Config object exposing ``time_type``. When ``time_type`` is
            ``'attention'`` it must also expose ``embed_dim``, ``num_heads``
            and ``bias_type``.

    Returns:
        A ``functools.partial`` of ``AttentionBlock`` with ``hidden_dim``,
        ``num_heads`` (positionally) and ``bias_type`` already bound; any
        remaining ``AttentionBlock`` arguments can be passed when it is called.

    Raises:
        NotImplementedError: If ``params.time_type`` is not a supported type.
    """
    if params.time_type == 'attention':
        return partial(AttentionBlock, params.embed_dim, params.num_heads, bias_type=params.bias_type)
    # Name the offending value so a config typo is easy to spot.
    raise NotImplementedError(f"Unsupported time_type: {params.time_type!r}")

# class InstanceNormNd(nn.Module):
#     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
#         super().__init__()
#         self.num_features = num_features
#         self.eps = eps
#         self.momentum = momentum
#         self.affine = affine
#         self.track_running_stats = track_running_stats

#         if affine:
#             self.weight = nn.Parameter(torch.ones(num_features))
#             self.bias = nn.Parameter(torch.zeros(num_features))
#         else:
#             self.register_parameter("weight", None)
#             self.register_parameter("bias", None)

#         if track_running_stats:
#             self.register_buffer("running_mean", torch.zeros(num_features))
#             self.register_buffer("running_var", torch.ones(num_features))
#         else:
#             self.register_buffer("running_mean", None)
#             self.register_buffer("running_var", None)

#     def forward(self, x):
#         N, C = x.shape[:2]
#         spatial_shape = x.shape[2:]
#         x_flat = x.view(N, C, -1)

#         if not self.training and self.track_running_stats:
#             with torch.no_grad():
#                 x_norm = F.instance_norm(
#                     x_flat,
#                     running_mean=self.running_mean,
#                     running_var=self.running_var,
#                     weight=self.weight,
#                     bias=self.bias,
#                     use_input_stats=False,
#                     momentum=self.momentum,
#                     eps=self.eps
#                 )
#         else:
#             x_norm = F.instance_norm(
#                 x_flat,
#                 running_mean=self.running_mean,
#                 running_var=self.running_var,
#                 weight=self.weight,
#                 bias=self.bias,
#                 use_input_stats=True,
#                 momentum=self.momentum,
#                 eps=self.eps
#             )

#         return x_norm.view(N, C, *spatial_shape)

class InstanceNormNd(nn.Module):
    """Instance normalization for inputs with any number of trailing dims.

    Drop-in replacement for ``nn.InstanceNorm{1,2,3}d`` built on
    ``nn.GroupNorm``: with one group per channel, every (sample, channel)
    slice is normalized over all of its remaining dimensions. GroupNorm keeps
    no running statistics, so training and eval modes behave identically.

    Args:
        num_channels: Number of channels ``C`` in inputs shaped ``(N, C, *)``.
        eps: Value added to the variance for numerical stability.
        affine: If True, learn a per-channel scale and shift.
    """

    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__()
        # Using as many groups as channels puts each channel in its own
        # group, which is exactly instance normalization.
        group_count = num_channels
        self.norm = nn.GroupNorm(group_count, num_channels, eps, affine)

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *)``; the output keeps that shape."""
        normalized = self.norm(x)
        return normalized

    

class AttentionBlock(nn.Module):
    """Residual multi-head self-attention along the time axis.

    Each input tensor is laid out as ``(T, B, C, *spatial)``. Attention is
    computed independently for every (batch element, spatial location) pair,
    with the ``T`` time steps forming the attention sequence.

    Args:
        hidden_dim: Channel dimension ``C`` of the inputs; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        drop_path: Stochastic-depth rate for the residual branch
            (``0`` disables it).
        layer_scale_init_value: Initial value of the per-channel layer-scale
            ``gamma``; a value ``<= 0`` disables layer scale.
        bias_type: ``'none'`` for no positional bias, ``'continuous'`` for
            ``ContinuousPositionBias1D``; any other value selects
            ``RelativePositionBias``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim=768, num_heads=12, drop_path=0, layer_scale_init_value=1e-6, bias_type='rel'):
        super().__init__()
        if hidden_dim % num_heads != 0:
            # Such a config would otherwise only fail deep inside forward().
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )
        head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        # Per-sample, per-channel normalization before and after attention.
        self.norm1 = InstanceNormNd(hidden_dim, affine=True)
        self.norm2 = InstanceNormNd(hidden_dim, affine=True)
        # Layer scale on the residual branch; None means "disabled".
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(hidden_dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.input_head = nn.Linear(hidden_dim, 3 * hidden_dim)  # fused Q, K, V projection
        self.output_head = nn.Linear(hidden_dim, hidden_dim)
        # LayerNorm applied to each head's query / key vectors.
        self.qnorm = nn.LayerNorm(head_dim)
        self.knorm = nn.LayerNorm(head_dim)
        if bias_type == 'none':
            self.rel_pos_bias = lambda x, y: None
        elif bias_type == 'continuous':
            self.rel_pos_bias = ContinuousPositionBias1D(n_heads=num_heads)
        else:
            self.rel_pos_bias = RelativePositionBias(n_heads=num_heads)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_list):
        """Apply temporal attention to every tensor in ``x_list``.

        Args:
            x_list: Iterable of tensors, each shaped ``(T, B, C, *spatial)``
                with ``C == hidden_dim``. Tensors are processed independently,
                so their ``T``, ``B`` and spatial sizes may differ.

        Returns:
            A list with one output per input, each with the input's shape.
        """
        return [self._attend(x) for x in x_list]

    def _attend(self, x):
        """Residual temporal attention for one ``(T, B, C, *spatial)`` tensor."""
        T = x.shape[0]
        spatial = x.shape[3:]
        axes = {f"s{i}": size for i, size in enumerate(spatial)}
        space = " ".join(axes)
        # No op below is in-place, so the input can serve as the residual
        # directly (no clone needed).
        residual = x

        # Fold time into batch and pre-normalize.
        x = rearrange(x, 't b c ... -> (t b) c ...')
        x = self.norm1(x)

        # Q, K, V projection (nn.Linear acts on the last dim, so go channels-last).
        x = rearrange(x, 'b c ... -> b ... c')
        x = self.input_head(x)
        x = rearrange(x, 'b ... c -> b c ...')

        # One length-T sequence per (batch, spatial location, head).
        x = rearrange(x, f'(t b) (he c) {space} -> (b {space}) he t c', t=T, he=self.num_heads)
        q, k, v = x.tensor_split(3, dim=-1)
        q, k = self.qnorm(q), self.knorm(k)
        # tensor_split yields strided views; make v dense like the old
        # unbiased branch did. q and k are fresh tensors from LayerNorm.
        v = v.contiguous()
        # attn_mask=None is equivalent to an unmasked call, so a single call
        # covers both the biased and unbiased configurations.
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=self.rel_pos_bias(T, T))

        # Undo the attention layout.
        x = rearrange(x, f'(b {space}) he t c -> (t b) (he c) {space}', **axes)

        x = self.norm2(x)
        x = rearrange(x, 'b c ... -> b ... c')
        x = self.output_head(x)
        x = rearrange(x, 'b ... c -> b c ...')
        x = rearrange(x, '(t b) c ... -> t b c ...', t=T)

        # Layer scale is optional: gamma is None when layer_scale_init_value <= 0.
        if self.gamma is not None:
            x = x * self.gamma.view(1, 1, -1, *([1] * len(spatial)))
        return self.drop_path(x) + residual
